{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import tensorflow as tf\n",
    "from arctic import Arctic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Open an Arctic connection to the store on 'localhost'\n",
    "a = Arctic('localhost')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select the 'fund_holding' library from the Arctic connection\n",
    "lib = a['fund_holding']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read every symbol in the library and tag its rows with the symbol they came from\n",
    "per_symbol_frames = [\n",
    "    lib.read(fund).data.assign(symbol=fund)\n",
    "    for fund in lib.list_symbols()\n",
    "]\n",
    "holding = pd.concat(per_symbol_frames)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only rows with a strictly positive holding\n",
    "holding = holding.loc[holding['holding'] > 0.0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Name the index 'ts_code' and move it into an ordinary column\n",
    "holding = holding.rename_axis('ts_code').reset_index()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distinct fund symbols and stock codes, in order of first appearance\n",
    "all_funds = holding['symbol'].unique()\n",
    "all_stocks = holding['ts_code'].unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dense integer id for every stock code, following its position in all_stocks\n",
    "stock_to_id = dict(zip(all_stocks, range(len(all_stocks))))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Build a bipartite fund-stock graph from the fund holdings: funds and stocks are the two node sets, and each holding becomes a weighted edge."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [],
   "source": [
    "g = nx.Graph()\n",
    "# Register both node sets, labelling each node with the side it belongs to\n",
    "for node_kind, members in (('fund', all_funds), ('stock', all_stocks)):\n",
    "    g.add_nodes_from(members, bipartite=node_kind)\n",
    "# One (fund, stock, holding) triple per row; the holding is the edge weight\n",
    "w_edges = list(zip(holding['symbol'].tolist(),\n",
    "                   holding['ts_code'].tolist(),\n",
    "                   holding['holding'].tolist()))\n",
    "g.add_weighted_edges_from(w_edges)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
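    "Generate random walks on the fund-stock graph. Each walk starts at a random stock and alternates stock -> fund -> stock, picking each neighbour with probability proportional to the holding weight."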
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [],
   "source": [
    "def choose(g, fund_or_stock):\n",
    "    \"\"\"\n",
    "    Pick one neighbour of `fund_or_stock` in `g` at random.\n",
    "\n",
    "    Each neighbour is drawn with probability proportional to the\n",
    "    'weight' attribute of the edge that connects it to `fund_or_stock`.\n",
    "    \"\"\"\n",
    "    candidates = []\n",
    "    edge_weights = []\n",
    "    for neighbour, attrs in g[fund_or_stock].items():\n",
    "        candidates.append(neighbour)\n",
    "        edge_weights.append(attrs['weight'])\n",
    "    edge_weights = np.array(edge_weights)\n",
    "    probabilities = edge_weights / edge_weights.sum()\n",
    "    return np.random.choice(candidates, p=probabilities)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [],
   "source": [
    "paths = []\n",
    "for i in range(1000):\n",
    "    # every walk starts from a stock drawn uniformly at random\n",
    "    stock = np.random.choice(all_stocks)\n",
    "    visited = [stock]\n",
    "    for t in range(200):\n",
    "        # one step: hop from the current stock to one of its funds,\n",
    "        # then from that fund to one of the stocks it holds\n",
    "        fund = choose(g, visited[-1])\n",
    "        stock = choose(g, fund)\n",
    "        visited.append(stock)\n",
    "    # store the walk as integer stock ids\n",
    "    path = [stock_to_id[code] for code in visited]\n",
    "    paths.append(path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the walks into one 2-D array: one row per walk, 201 stock ids per row (start + 200 steps)\n",
    "paths = np.array(paths)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_target(path, idx, window_size=5):\n",
    "    \"\"\"\n",
    "    Collect the distinct stocks surrounding position `idx` of `path`.\n",
    "\n",
    "    The window radius is drawn at random from 1..window_size on every\n",
    "    call and the window is clipped at the start of the path. The element\n",
    "    at `idx` itself is skipped, although an equal id found elsewhere in\n",
    "    the window is kept. Returns the unique ids as a list, in no\n",
    "    particular order.\n",
    "    \"\"\"\n",
    "    radius = np.random.randint(1, window_size+1)\n",
    "    lo = max(idx - radius, 0)\n",
    "    hi = idx + radius + 1\n",
    "    context = set(path[lo:idx])\n",
    "    context.update(path[idx+1:hi])\n",
    "    return list(context)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flatten every walk into (input stock, context stock) training pairs\n",
    "x_batches, y_batches = [], []\n",
    "for path in paths:\n",
    "    for i, x in enumerate(path):\n",
    "        y = get_target(path, i, 5)\n",
    "        # repeat the input once per context stock so both lists stay aligned\n",
    "        x_batches += [x] * len(y)\n",
    "        y_batches += y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "([613, 613, 613, 75, 75, 224, 224, 224, 224, 224],\n",
       " [224, 75, 166, 224, 613, 224, 613, 166, 75, 13])"
      ]
     },
     "execution_count": 95,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Peek at the first ten (input, target) training pairs\n",
    "x_batches[:10], y_batches[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clear the default TF graph before the model cells below define their ops\n",
    "tf.reset_default_graph()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Batch of input stock ids (1-D)\n",
    "inputs = tf.placeholder(tf.int32, [None], name='inputs')\n",
    "# Target stock ids (2-D); the training loop feeds them as a single column, np.array(y)[:, None]\n",
    "labels = tf.placeholder(tf.int32, [None, None], name='labels')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_stocks = len(all_stocks)\n",
    "n_embedding = 32 # Number of embedding features \n",
    "# Trainable table with one n_embedding-wide row per stock id, initialised uniformly between -1 and 1\n",
    "embedding = tf.Variable(tf.random_uniform((n_stocks, n_embedding), -1, 1))\n",
    "# Rows of the table selected by the ids fed through `inputs`\n",
    "embed = tf.nn.embedding_lookup(embedding, inputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of negative labels to sample\n",
    "n_sampled = 100\n",
    "softmax_w = tf.Variable(tf.truncated_normal((n_stocks, n_embedding), stddev=0.1))\n",
    "softmax_b = tf.Variable(tf.zeros(n_stocks))\n",
    "\n",
    "# Calculate the loss using negative sampling\n",
    "loss = tf.nn.sampled_softmax_loss(softmax_w, softmax_b, \n",
    "                                  labels, embed,\n",
    "                                  n_sampled, n_stocks)\n",
    "\n",
    "cost = tf.reduce_mean(loss)\n",
    "optimizer = tf.train.AdamOptimizer().minimize(cost)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/2 Iteration: 100 Avg. Training loss: 5.3346 0.0054 sec/batch\n",
      "Epoch 1/2 Iteration: 200 Avg. Training loss: 5.3439 0.0030 sec/batch\n",
      "Epoch 1/2 Iteration: 300 Avg. Training loss: 5.2803 0.0030 sec/batch\n",
      "Epoch 1/2 Iteration: 400 Avg. Training loss: 5.1203 0.0031 sec/batch\n",
      "Epoch 1/2 Iteration: 500 Avg. Training loss: 5.1350 0.0032 sec/batch\n",
      "Epoch 1/2 Iteration: 600 Avg. Training loss: 4.9596 0.0030 sec/batch\n",
      "Epoch 1/2 Iteration: 700 Avg. Training loss: 4.7762 0.0033 sec/batch\n",
      "Epoch 1/2 Iteration: 800 Avg. Training loss: 4.6565 0.0031 sec/batch\n",
      "Epoch 1/2 Iteration: 900 Avg. Training loss: 4.4653 0.0030 sec/batch\n",
      "Epoch 1/2 Iteration: 1000 Avg. Training loss: 4.4613 0.0030 sec/batch\n",
      "Epoch 1/2 Iteration: 1100 Avg. Training loss: 4.3459 0.0037 sec/batch\n",
      "Epoch 1/2 Iteration: 1200 Avg. Training loss: 4.3773 0.0034 sec/batch\n",
      "Epoch 1/2 Iteration: 1300 Avg. Training loss: 4.2550 0.0032 sec/batch\n",
      "Epoch 1/2 Iteration: 1400 Avg. Training loss: 4.2175 0.0030 sec/batch\n",
      "Epoch 1/2 Iteration: 1500 Avg. Training loss: 4.2088 0.0030 sec/batch\n",
      "Epoch 1/2 Iteration: 1600 Avg. Training loss: 4.2001 0.0031 sec/batch\n",
      "Epoch 1/2 Iteration: 1700 Avg. Training loss: 4.1509 0.0030 sec/batch\n",
      "Epoch 1/2 Iteration: 1800 Avg. Training loss: 4.1315 0.0033 sec/batch\n",
      "Epoch 1/2 Iteration: 1900 Avg. Training loss: 4.0592 0.0034 sec/batch\n",
      "Epoch 1/2 Iteration: 2000 Avg. Training loss: 4.1054 0.0033 sec/batch\n",
      "Epoch 1/2 Iteration: 2100 Avg. Training loss: 4.0008 0.0033 sec/batch\n",
      "Epoch 1/2 Iteration: 2200 Avg. Training loss: 4.0128 0.0034 sec/batch\n",
      "Epoch 1/2 Iteration: 2300 Avg. Training loss: 4.0282 0.0036 sec/batch\n",
      "Epoch 1/2 Iteration: 2400 Avg. Training loss: 4.0624 0.0038 sec/batch\n",
      "Epoch 1/2 Iteration: 2500 Avg. Training loss: 4.0107 0.0033 sec/batch\n",
      "Epoch 1/2 Iteration: 2600 Avg. Training loss: 4.0178 0.0037 sec/batch\n",
      "Epoch 1/2 Iteration: 2700 Avg. Training loss: 3.9757 0.0034 sec/batch\n",
      "Epoch 1/2 Iteration: 2800 Avg. Training loss: 3.9492 0.0034 sec/batch\n",
      "Epoch 1/2 Iteration: 2900 Avg. Training loss: 3.9486 0.0033 sec/batch\n",
      "Epoch 1/2 Iteration: 3000 Avg. Training loss: 3.9127 0.0036 sec/batch\n",
      "Epoch 1/2 Iteration: 3100 Avg. Training loss: 3.9124 0.0035 sec/batch\n",
      "Epoch 1/2 Iteration: 3200 Avg. Training loss: 3.8779 0.0033 sec/batch\n",
      "Epoch 1/2 Iteration: 3300 Avg. Training loss: 3.9269 0.0036 sec/batch\n",
      "Epoch 1/2 Iteration: 3400 Avg. Training loss: 3.9506 0.0037 sec/batch\n",
      "Epoch 1/2 Iteration: 3500 Avg. Training loss: 3.9224 0.0033 sec/batch\n",
      "Epoch 1/2 Iteration: 3600 Avg. Training loss: 3.9056 0.0033 sec/batch\n",
      "Epoch 1/2 Iteration: 3700 Avg. Training loss: 3.8675 0.0034 sec/batch\n",
      "Epoch 1/2 Iteration: 3800 Avg. Training loss: 3.9044 0.0033 sec/batch\n",
      "Epoch 1/2 Iteration: 3900 Avg. Training loss: 3.8757 0.0030 sec/batch\n",
      "Epoch 1/2 Iteration: 4000 Avg. Training loss: 3.9049 0.0031 sec/batch\n",
      "Epoch 1/2 Iteration: 4100 Avg. Training loss: 3.9407 0.0035 sec/batch\n",
      "Epoch 1/2 Iteration: 4200 Avg. Training loss: 3.8648 0.0030 sec/batch\n",
      "Epoch 1/2 Iteration: 4300 Avg. Training loss: 3.8685 0.0030 sec/batch\n",
      "Epoch 1/2 Iteration: 4400 Avg. Training loss: 3.8400 0.0030 sec/batch\n",
      "Epoch 1/2 Iteration: 4500 Avg. Training loss: 3.8681 0.0030 sec/batch\n",
      "Epoch 1/2 Iteration: 4600 Avg. Training loss: 3.8646 0.0030 sec/batch\n",
      "Epoch 1/2 Iteration: 4700 Avg. Training loss: 3.8713 0.0035 sec/batch\n",
      "Epoch 1/2 Iteration: 4800 Avg. Training loss: 3.8325 0.0034 sec/batch\n",
      "Epoch 1/2 Iteration: 4900 Avg. Training loss: 3.9154 0.0032 sec/batch\n",
      "Epoch 1/2 Iteration: 5000 Avg. Training loss: 3.8818 0.0030 sec/batch\n",
      "Epoch 1/2 Iteration: 5100 Avg. Training loss: 3.8101 0.0030 sec/batch\n",
      "Epoch 1/2 Iteration: 5200 Avg. Training loss: 3.7814 0.0030 sec/batch\n",
      "Epoch 1/2 Iteration: 5300 Avg. Training loss: 3.8813 0.0031 sec/batch\n",
      "Epoch 1/2 Iteration: 5400 Avg. Training loss: 3.8584 0.0030 sec/batch\n",
      "Epoch 1/2 Iteration: 5500 Avg. Training loss: 3.8904 0.0031 sec/batch\n",
      "Epoch 1/2 Iteration: 5600 Avg. Training loss: 3.8557 0.0032 sec/batch\n",
      "Epoch 1/2 Iteration: 5700 Avg. Training loss: 3.8679 0.0036 sec/batch\n",
      "Epoch 1/2 Iteration: 5800 Avg. Training loss: 3.8873 0.0029 sec/batch\n",
      "Epoch 1/2 Iteration: 5900 Avg. Training loss: 3.8221 0.0031 sec/batch\n",
      "Epoch 1/2 Iteration: 6000 Avg. Training loss: 3.8422 0.0030 sec/batch\n",
      "Epoch 1/2 Iteration: 6100 Avg. Training loss: 3.8211 0.0033 sec/batch\n",
      "Epoch 1/2 Iteration: 6200 Avg. Training loss: 3.8596 0.0034 sec/batch\n",
      "Epoch 1/2 Iteration: 6300 Avg. Training loss: 3.8377 0.0032 sec/batch\n",
      "Epoch 1/2 Iteration: 6400 Avg. Training loss: 3.8674 0.0030 sec/batch\n",
      "Epoch 1/2 Iteration: 6500 Avg. Training loss: 3.8852 0.0030 sec/batch\n",
      "Epoch 1/2 Iteration: 6600 Avg. Training loss: 3.8489 0.0030 sec/batch\n",
      "Epoch 1/2 Iteration: 6700 Avg. Training loss: 3.7788 0.0030 sec/batch\n",
      "Epoch 1/2 Iteration: 6800 Avg. Training loss: 3.7647 0.0031 sec/batch\n",
      "Epoch 1/2 Iteration: 6900 Avg. Training loss: 3.8456 0.0030 sec/batch\n",
      "Epoch 1/2 Iteration: 7000 Avg. Training loss: 3.8145 0.0030 sec/batch\n",
      "Epoch 1/2 Iteration: 7100 Avg. Training loss: 3.6880 0.0030 sec/batch\n",
      "Epoch 1/2 Iteration: 7200 Avg. Training loss: 3.8628 0.0031 sec/batch\n",
      "Epoch 1/2 Iteration: 7300 Avg. Training loss: 3.7836 0.0035 sec/batch\n",
      "Epoch 1/2 Iteration: 7400 Avg. Training loss: 3.7665 0.0030 sec/batch\n",
      "Epoch 1/2 Iteration: 7500 Avg. Training loss: 3.8244 0.0032 sec/batch\n",
      "Epoch 1/2 Iteration: 7600 Avg. Training loss: 3.7869 0.0033 sec/batch\n",
      "Epoch 1/2 Iteration: 7700 Avg. Training loss: 3.8586 0.0031 sec/batch\n",
      "Epoch 1/2 Iteration: 7800 Avg. Training loss: 3.8140 0.0032 sec/batch\n",
      "Epoch 1/2 Iteration: 7900 Avg. Training loss: 3.8063 0.0031 sec/batch\n",
      "Epoch 1/2 Iteration: 8000 Avg. Training loss: 3.7666 0.0030 sec/batch\n",
      "Epoch 2/2 Iteration: 8100 Avg. Training loss: 3.8242 0.0011 sec/batch\n",
      "Epoch 2/2 Iteration: 8200 Avg. Training loss: 3.7972 0.0031 sec/batch\n",
      "Epoch 2/2 Iteration: 8300 Avg. Training loss: 3.8070 0.0030 sec/batch\n",
      "Epoch 2/2 Iteration: 8400 Avg. Training loss: 3.7800 0.0030 sec/batch\n",
      "Epoch 2/2 Iteration: 8500 Avg. Training loss: 3.7664 0.0031 sec/batch\n",
      "Epoch 2/2 Iteration: 8600 Avg. Training loss: 3.7972 0.0030 sec/batch\n",
      "Epoch 2/2 Iteration: 8700 Avg. Training loss: 3.7772 0.0032 sec/batch\n",
      "Epoch 2/2 Iteration: 8800 Avg. Training loss: 3.8171 0.0034 sec/batch\n",
      "Epoch 2/2 Iteration: 8900 Avg. Training loss: 3.5283 0.0037 sec/batch\n",
      "Epoch 2/2 Iteration: 9000 Avg. Training loss: 3.6785 0.0030 sec/batch\n",
      "Epoch 2/2 Iteration: 9100 Avg. Training loss: 3.7887 0.0030 sec/batch\n",
      "Epoch 2/2 Iteration: 9200 Avg. Training loss: 3.6477 0.0030 sec/batch\n",
      "Epoch 2/2 Iteration: 9300 Avg. Training loss: 3.7627 0.0032 sec/batch\n",
      "Epoch 2/2 Iteration: 9400 Avg. Training loss: 3.7004 0.0033 sec/batch\n",
      "Epoch 2/2 Iteration: 9500 Avg. Training loss: 3.7223 0.0030 sec/batch\n",
      "Epoch 2/2 Iteration: 9600 Avg. Training loss: 3.7813 0.0030 sec/batch\n",
      "Epoch 2/2 Iteration: 9700 Avg. Training loss: 3.7623 0.0031 sec/batch\n",
      "Epoch 2/2 Iteration: 9800 Avg. Training loss: 3.7664 0.0033 sec/batch\n",
      "Epoch 2/2 Iteration: 9900 Avg. Training loss: 3.7830 0.0031 sec/batch\n",
      "Epoch 2/2 Iteration: 10000 Avg. Training loss: 3.5724 0.0031 sec/batch\n",
      "Epoch 2/2 Iteration: 10100 Avg. Training loss: 3.7872 0.0033 sec/batch\n",
      "Epoch 2/2 Iteration: 10200 Avg. Training loss: 3.6990 0.0035 sec/batch\n",
      "Epoch 2/2 Iteration: 10300 Avg. Training loss: 3.7388 0.0033 sec/batch\n",
      "Epoch 2/2 Iteration: 10400 Avg. Training loss: 3.7604 0.0033 sec/batch\n",
      "Epoch 2/2 Iteration: 10500 Avg. Training loss: 3.7788 0.0032 sec/batch\n",
      "Epoch 2/2 Iteration: 10600 Avg. Training loss: 3.7811 0.0030 sec/batch\n",
      "Epoch 2/2 Iteration: 10700 Avg. Training loss: 3.7804 0.0030 sec/batch\n",
      "Epoch 2/2 Iteration: 10800 Avg. Training loss: 3.7225 0.0030 sec/batch\n",
      "Epoch 2/2 Iteration: 10900 Avg. Training loss: 3.7594 0.0030 sec/batch\n",
      "Epoch 2/2 Iteration: 11000 Avg. Training loss: 3.7701 0.0032 sec/batch\n",
      "Epoch 2/2 Iteration: 11100 Avg. Training loss: 3.5425 0.0030 sec/batch\n",
      "Epoch 2/2 Iteration: 11200 Avg. Training loss: 3.6653 0.0030 sec/batch\n",
      "Epoch 2/2 Iteration: 11300 Avg. Training loss: 3.7724 0.0031 sec/batch\n",
      "Epoch 2/2 Iteration: 11400 Avg. Training loss: 3.7315 0.0030 sec/batch\n",
      "Epoch 2/2 Iteration: 11500 Avg. Training loss: 3.7700 0.0033 sec/batch\n",
      "Epoch 2/2 Iteration: 11600 Avg. Training loss: 3.7370 0.0036 sec/batch\n",
      "Epoch 2/2 Iteration: 11700 Avg. Training loss: 3.7636 0.0032 sec/batch\n",
      "Epoch 2/2 Iteration: 11800 Avg. Training loss: 3.6846 0.0030 sec/batch\n",
      "Epoch 2/2 Iteration: 11900 Avg. Training loss: 3.7136 0.0030 sec/batch\n",
      "Epoch 2/2 Iteration: 12000 Avg. Training loss: 3.6141 0.0032 sec/batch\n",
      "Epoch 2/2 Iteration: 12100 Avg. Training loss: 3.7296 0.0032 sec/batch\n",
      "Epoch 2/2 Iteration: 12200 Avg. Training loss: 3.7400 0.0030 sec/batch\n",
      "Epoch 2/2 Iteration: 12300 Avg. Training loss: 3.7059 0.0030 sec/batch\n",
      "Epoch 2/2 Iteration: 12400 Avg. Training loss: 3.7164 0.0030 sec/batch\n",
      "Epoch 2/2 Iteration: 12500 Avg. Training loss: 3.7485 0.0030 sec/batch\n",
      "Epoch 2/2 Iteration: 12600 Avg. Training loss: 3.6792 0.0031 sec/batch\n",
      "Epoch 2/2 Iteration: 12700 Avg. Training loss: 3.7730 0.0030 sec/batch\n",
      "Epoch 2/2 Iteration: 12800 Avg. Training loss: 3.7024 0.0030 sec/batch\n",
      "Epoch 2/2 Iteration: 12900 Avg. Training loss: 3.7282 0.0030 sec/batch\n",
      "Epoch 2/2 Iteration: 13000 Avg. Training loss: 3.6928 0.0030 sec/batch\n",
      "Epoch 2/2 Iteration: 13100 Avg. Training loss: 3.7280 0.0030 sec/batch\n",
      "Epoch 2/2 Iteration: 13200 Avg. Training loss: 3.7224 0.0031 sec/batch\n",
      "Epoch 2/2 Iteration: 13300 Avg. Training loss: 3.5519 0.0030 sec/batch\n",
      "Epoch 2/2 Iteration: 13400 Avg. Training loss: 3.7726 0.0030 sec/batch\n",
      "Epoch 2/2 Iteration: 13500 Avg. Training loss: 3.6992 0.0030 sec/batch\n",
      "Epoch 2/2 Iteration: 13600 Avg. Training loss: 3.7404 0.0032 sec/batch\n",
      "Epoch 2/2 Iteration: 13700 Avg. Training loss: 3.6699 0.0036 sec/batch\n",
      "Epoch 2/2 Iteration: 13800 Avg. Training loss: 3.7735 0.0030 sec/batch\n",
      "Epoch 2/2 Iteration: 13900 Avg. Training loss: 3.7189 0.0036 sec/batch\n",
      "Epoch 2/2 Iteration: 14000 Avg. Training loss: 3.7348 0.0034 sec/batch\n",
      "Epoch 2/2 Iteration: 14100 Avg. Training loss: 3.6823 0.0030 sec/batch\n",
      "Epoch 2/2 Iteration: 14200 Avg. Training loss: 3.7479 0.0031 sec/batch\n",
      "Epoch 2/2 Iteration: 14300 Avg. Training loss: 3.6989 0.0030 sec/batch\n",
      "Epoch 2/2 Iteration: 14400 Avg. Training loss: 3.7429 0.0030 sec/batch\n",
      "Epoch 2/2 Iteration: 14500 Avg. Training loss: 3.7509 0.0032 sec/batch\n",
      "Epoch 2/2 Iteration: 14600 Avg. Training loss: 3.7420 0.0031 sec/batch\n",
      "Epoch 2/2 Iteration: 14700 Avg. Training loss: 3.7245 0.0030 sec/batch\n",
      "Epoch 2/2 Iteration: 14800 Avg. Training loss: 3.6008 0.0031 sec/batch\n",
      "Epoch 2/2 Iteration: 14900 Avg. Training loss: 3.7295 0.0031 sec/batch\n",
      "Epoch 2/2 Iteration: 15000 Avg. Training loss: 3.7446 0.0031 sec/batch\n",
      "Epoch 2/2 Iteration: 15100 Avg. Training loss: 3.5852 0.0031 sec/batch\n",
      "Epoch 2/2 Iteration: 15200 Avg. Training loss: 3.6885 0.0034 sec/batch\n",
      "Epoch 2/2 Iteration: 15300 Avg. Training loss: 3.7128 0.0038 sec/batch\n",
      "Epoch 2/2 Iteration: 15400 Avg. Training loss: 3.6527 0.0033 sec/batch\n",
      "Epoch 2/2 Iteration: 15500 Avg. Training loss: 3.6493 0.0031 sec/batch\n",
      "Epoch 2/2 Iteration: 15600 Avg. Training loss: 3.7307 0.0030 sec/batch\n",
      "Epoch 2/2 Iteration: 15700 Avg. Training loss: 3.7327 0.0031 sec/batch\n",
      "Epoch 2/2 Iteration: 15800 Avg. Training loss: 3.7627 0.0031 sec/batch\n",
      "Epoch 2/2 Iteration: 15900 Avg. Training loss: 3.6855 0.0030 sec/batch\n",
      "Epoch 2/2 Iteration: 16000 Avg. Training loss: 3.7025 0.0030 sec/batch\n",
      "Epoch 2/2 Iteration: 16100 Avg. Training loss: 3.6899 0.0031 sec/batch\n",
      "WARNING:tensorflow:From <ipython-input-107-ff24793480a7>:32: calling reduce_sum_v1 (from tensorflow.python.ops.math_ops) with keep_dims is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "keep_dims is deprecated, use keepdims instead\n"
     ]
    }
   ],
   "source": [
    "epochs = 2\n",
    "batch_size = 128\n",
    "log_every = 100  # print averaged loss and timing every this many batches\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    iteration = 1\n",
    "    # Accumulate batch costs under their own name so the `loss` tensor\n",
    "    # defined with the model is not overwritten by a Python number.\n",
    "    running_loss = 0\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "\n",
    "    # Start the timer once: the reporting window is counted in iterations\n",
    "    # and can straddle an epoch boundary, so resetting it at each epoch\n",
    "    # start would under-report sec/batch for the first window of an epoch.\n",
    "    start = time.time()\n",
    "    for e in range(1, epochs+1):\n",
    "        for idx in range(0, len(x_batches), batch_size):\n",
    "            x = x_batches[idx:idx+batch_size]\n",
    "            y = y_batches[idx:idx+batch_size]\n",
    "\n",
    "            # labels are fed as a column: one target id per input id\n",
    "            feed = {inputs: x,\n",
    "                    labels: np.array(y)[:, None]}\n",
    "            train_loss, _ = sess.run([cost, optimizer], feed_dict=feed)\n",
    "\n",
    "            running_loss += train_loss\n",
    "\n",
    "            if iteration % log_every == 0:\n",
    "                end = time.time()\n",
    "                print(\"Epoch {}/{}\".format(e, epochs),\n",
    "                      \"Iteration: {}\".format(iteration),\n",
    "                      \"Avg. Training loss: {:.4f}\".format(running_loss/log_every),\n",
    "                      \"{:.4f} sec/batch\".format((end-start)/log_every))\n",
    "                running_loss = 0\n",
    "                start = time.time()\n",
    "\n",
    "            iteration += 1\n",
    "\n",
    "    # Scale every embedding row to unit L2 norm before exporting it.\n",
    "    # `keepdims` replaces the deprecated `keep_dims` argument.\n",
    "    norm = tf.sqrt(tf.reduce_sum(tf.square(embedding), 1, keepdims=True))\n",
    "    normalized_embedding = embedding / norm\n",
    "    embed_mat = sess.run(normalized_embedding)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pairwise dot products between all rows of embed_mat\n",
    "sims = embed_mat.dot(embed_mat.T)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 117,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-0.21077634,  0.25182042, -0.24591663, -0.20395333,  0.0882552 ,\n",
       "        0.05960291,  0.06305382, -0.03513731,  0.03025971, -0.31243148,\n",
       "       -0.19897626,  0.3326856 , -0.08734539,  0.00362529, -0.07794005,\n",
       "       -0.00431162,  0.0969455 ,  0.27191335,  0.24857497, -0.09293828,\n",
       "       -0.06077765,  0.0853738 ,  0.18494187,  0.0061985 ,  0.13451672,\n",
       "        0.22502415, -0.15815932,  0.27767122,  0.3023177 , -0.02005056,\n",
       "       -0.22100852, -0.02289582], dtype=float32)"
      ]
     },
     "execution_count": 117,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Normalised embedding vector of the stock with id 0\n",
    "embed_mat[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.4 64-bit ('cvxpyenv': conda)",
   "language": "python",
   "name": "python37464bitcvxpyenvconda7109fb22449841ac9c9cd08247340af1"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
